import torch
import numpy as np
from instance import *
import collections
from matplotlib import pyplot as plt
from model.models import *
from complementary_loss import *
import random
import argparse
# Use the current CUDA device when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Command-line configuration. Arguments are registered in the same order as they
# appear in --help; --loss and --data are mandatory.
parser = argparse.ArgumentParser(
    description='IDCLL for mnist,kmnist and fmnist',
)
parser.add_argument(
    '-lr', '--learning_rate',
    type=float, default=1e-3,
    help="optimizer's learning rate",
)
parser.add_argument(
    '-bs', '--batch_size',
    type=int, default=256,
    help='batch_size of ordinary labels.',
)
parser.add_argument(
    '-lo', '--loss',
    type=str, required=True,
    choices=['forward_loss', 'scl_nl', 'scl_exp', 'pc_loss',
             'w_loss', 'porden', 'nn', 'ovr_loss'],
    help='loss type',
)
parser.add_argument(
    '-da', '--data',
    type=str, required=True,
    choices=['mnist', 'kmnist', 'fmnist'],
    help='data type',
)
parser.add_argument(
    '-k', '--k',
    type=int, default=3,
    help='mink',
)
parser.add_argument(
    '-e', '--epochs',
    type=int, default=100,
    help='number of epochs',
)
parser.add_argument(
    '-wd', '--weight_decay',
    type=float, default=1e-4,
    help='weight decay',
)
parser.add_argument(
    '-se', '--seed',
    type=int, default=1,
    help='seed',
)
args = parser.parse_args()
def setup_seed(seed):
    """Seed every random number generator the script uses.

    PyTorch (CPU and all CUDA devices), NumPy's global RNG and Python's
    ``random`` module all receive the same ``seed``. The function also asks
    cuDNN to use deterministic algorithms so repeated runs line up.
    """
    for seeder in (torch.manual_seed, torch.cuda.manual_seed_all,
                   np.random.seed, random.seed):
        seeder(seed)
    torch.backends.cudnn.deterministic = True
setup_seed(seed=args.seed)  # seed all RNGs from the --seed CLI option before any data/model setup
# evaluation
def train_accuracy(eval_train_loader, model):
    """Return top-1 accuracy (percent, rounded to 2 decimals) on ``eval_train_loader``.

    Args:
        eval_train_loader: iterable of ``(inputs, labels)`` batches. Inputs are
            reshaped to ``(N, 1, 28, 28)`` before the forward pass.
        model: network whose forward returns a 2-tuple; the first element is
            arg-maxed over dim 1 to obtain the predicted class.

    Returns:
        ``100 * correct / total`` rounded to 2 decimals.

    Raises:
        ValueError: if the loader yields no samples.

    Note: ``model`` is switched to eval mode and left there.
    """
    model.eval()
    correct, num_samples = 0, 0
    # Evaluation never backpropagates, so skip autograd bookkeeping to save
    # memory and time.
    with torch.no_grad():
        for inputs, labels in eval_train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            inputs = inputs.reshape(inputs.shape[0], 1, 28, 28)
            outputs, _feat = model(inputs)
            predicted = outputs.argmax(dim=1)
            correct += (predicted == labels).sum().item()
            num_samples += labels.size(0)
    if num_samples == 0:
        raise ValueError("eval_train_loader yielded no samples")
    return round(100 * correct / num_samples, 2)

def test_accuracy(test_loader, model):
    """Return top-1 accuracy (percent, rounded to 2 decimals) on ``test_loader``.

    Args:
        test_loader: iterable of ``(inputs, labels)`` batches. Inputs are
            reshaped to ``(N, 1, 28, 28)`` before the forward pass.
        model: network whose forward returns a 2-tuple; the first element is
            arg-maxed over dim 1 to obtain the predicted class.

    Returns:
        ``100 * correct / total`` rounded to 2 decimals.

    Raises:
        ValueError: if the loader yields no samples.

    Note: ``model`` is switched to eval mode and left there.
    """
    model.eval()
    correct, num_samples = 0, 0
    # Evaluation never backpropagates, so skip autograd bookkeeping to save
    # memory and time.
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            inputs = inputs.reshape(inputs.shape[0], 1, 28, 28)
            outputs, _feat = model(inputs)
            predicted = outputs.argmax(dim=1)
            correct += (predicted == labels).sum().item()
            num_samples += labels.size(0)
    if num_samples == 0:
        raise ValueError("test_loader yielded no samples")
    return round(100 * correct / num_samples, 2)
# train
def train(train_loader, model, optimizer, epoch, lr, me, ccp_com, ccp_mincom, num_classes=10):
    """Run one training epoch with the complementary-label loss named by ``me``.

    Args:
        train_loader: iterable of ``(x, y, mincom, com)`` batches. Only ``x`` and
            ``mincom`` are used here; ``x`` is reshaped to ``(N, 1, 28, 28)``.
        model: network whose forward returns a 2-tuple; only the first element
            is passed to the loss.
        optimizer: optimizer for ``model``'s parameters; stepped once per batch.
        epoch: current epoch index. Unused; kept for call compatibility.
        lr: learning rate. Unused (the optimizer holds it); kept for call
            compatibility.
        me: loss name, one of ``'forward_loss'``, ``'scl_nl'``, ``'scl_exp'``,
            ``'pc_loss'``, ``'w_loss'``, ``'porden'``, ``'nn'``, ``'ovr_loss'``.
        ccp_com: unused here; kept for call compatibility.
        ccp_mincom: forwarded to ``non_negative_loss`` when ``me == 'nn'``.
        num_classes: class count passed to the losses that take one
            (``forward_loss``, ``pc_loss``, ``w_loss``, ``non_negative_loss``).
            Defaults to 10.

    Returns:
        Mean of the per-batch loss values, or 0.0 if the loader is empty.

    Raises:
        ValueError: if ``me`` is not one of the supported loss names.
    """
    # Bind every supported loss to a common (outputs, mincom) signature. The
    # loss functions are resolved by name at call time.
    loss_fns = {
        'forward_loss': lambda out, mc: forward_loss(out, num_classes, mc),
        'scl_nl': lambda out, mc: scl_nl(out, mc),
        'scl_exp': lambda out, mc: scl_exp(out, mc),
        'pc_loss': lambda out, mc: pc_loss(out, num_classes, mc),
        'w_loss': lambda out, mc: w_loss(out, num_classes, mc),
        'porden': lambda out, mc: partial_loss(out, mc),
        'nn': lambda out, mc: non_negative_loss(out, num_classes, mc, ccp_mincom, beta=0),
        'ovr_loss': lambda out, mc: ovr_loss(out, mc),
    }
    try:
        compute_loss = loss_fns[me]
    except KeyError:
        # Validate up front so an unknown name fails before any data is processed.
        raise ValueError(f"unsupported loss type: {me!r}") from None

    model.train()
    running_loss = 0.0
    num_batches = 0
    for x, _y, mincom, _com in train_loader:
        x, mincom = x.to(device), mincom.to(device)
        x = x.reshape(x.shape[0], 1, 28, 28)
        outputs, _feat = model(x)
        loss = compute_loss(outputs, mincom)
        optimizer.zero_grad()
        # Each batch builds a fresh graph that is backpropagated exactly once,
        # so retain_graph is unnecessary and would only keep buffers alive.
        # NOTE(review): assumes no loss in complementary_loss reuses graph
        # pieces across batches — confirm.
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        num_batches += 1
    return running_loss / num_batches if num_batches else 0.0
# Build the data loaders, the cnn_mnist model and its Adam optimizer.
# NOTE(review): the meaning of the 'resnet18' argument lives in instance.py — confirm.
train_loader, eval_train_loader, test_loader, ccp_com, ccp_mincom = (
    load_instance_dependent_dataloader(args.batch_size, args.data, args.k, 'resnet18')
)
model = cnn_mnist().to(device)
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=args.learning_rate,
    weight_decay=args.weight_decay,
)

best_acc = 0
for epoch in range(args.epochs):
    loss = train(
        train_loader, model, optimizer, epoch, args.learning_rate,
        me=args.loss, ccp_com=ccp_com, ccp_mincom=ccp_mincom,
    )
    model.eval()
    test_nat_acc = test_accuracy(test_loader, model)
    train_nat_acc = train_accuracy(eval_train_loader, model)
    # Keep the best test accuracy seen so far; ties keep the earlier value.
    best_acc = max(best_acc, test_nat_acc)
    print("epoch:", epoch, ";", "test_acc:", test_nat_acc,
          "train_acc:", train_nat_acc, "best_acc:", best_acc)
